<!DOCTYPE html>
<html lang="en">
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta name="twitter:title" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>
    <meta name="twitter:description" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/neox/samples/llm_int8.html"/>
    <meta property="og:title" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>
    <meta property="og:type" content="object"/>
    <meta property="og:title" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>
    <meta property="og:description" content="Generate Text with GPT-NeoX using LLM.int8() quantization"/>

    <title>Generate Text with GPT-NeoX using LLM.int8() quantization</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css?v=1">
    <link rel="canonical" href="https://nn.labml.ai/neox/samples/llm_int8.html"/>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        // Reuse the page's dataLayer queue if one is already defined;
        // otherwise start with an empty array.
        window.dataLayer = window.dataLayer || [];

        // Queue a command by pushing this call's `arguments` object (not a
        // copied array) onto dataLayer.
        function gtag() {
            dataLayer.push(arguments);
        }

        // Queue a 'js' command carrying the current date/time.
        gtag('js', new Date());

        // Queue a 'config' command for the same ID that is passed in the
        // gtag.js script URL above.
        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="../index.html">neox</a>
                <a class="parent" href="index.html">samples</a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
                    <img alt="GitHub"
                         src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
                         style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
                    <img alt="Twitter"
                         src="https://img.shields.io/twitter/follow/labmlai?style=social"
                         style="max-width:100%;"/></a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/samples/llm_int8.py" target="_blank">
                    View code on GitHub</a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
            <h1>Generate Text with GPT-NeoX using LLM.int8() quantization</h1>
<p>This shows how to generate text from GPT-NeoX using <a href="../utils/llm_int8.html">LLM.int8() quantization</a>.</p>
<p>This needs a GPU with 24GB memory.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">15</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">16</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">17</span>
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">monit</span>
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_nn.neox.model</span> <span class="kn">import</span> <span class="n">LayerGenerator</span>
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.neox.samples.generate</span> <span class="kn">import</span> <span class="n">PROMPT</span><span class="p">,</span> <span class="n">infer</span>
<span class="lineno">21</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils</span> <span class="kn">import</span> <span class="n">get_tokens</span><span class="p">,</span> <span class="n">print_tokens</span>
<span class="lineno">22</span><span class="kn">from</span> <span class="nn">labml_nn.neox.utils.cache</span> <span class="kn">import</span> <span class="n">get_cache</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-1'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-1'>#</a>
            </div>
            <h2>Generate text</h2>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">25</span><span class="k">def</span> <span class="nf">generate</span><span class="p">():</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-2'>#</a>
            </div>
            <p>Set up <a href="../utils/cache.html">cache</a> to cache intermediate key/value pairs for faster generation </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">31</span>    <span class="n">cache</span> <span class="o">=</span> <span class="n">get_cache</span><span class="p">()</span>
<span class="lineno">32</span>    <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;use_cache&#39;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-3'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-3'>#</a>
            </div>
            <p>Device </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">35</span>    <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-4'>#</a>
            </div>
            <p>Load layers in float16 into CPU. We convert the layers to int8 later, because doing that on the fly after loading layers to GPU causes CUDA memory fragmentation (about 3GB memory can get lost due to fragmentation). </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">40</span>    <span class="n">layer_generator</span> <span class="o">=</span> <span class="n">LayerGenerator</span><span class="p">(</span><span class="n">is_clone_layers</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">41</span>                                     <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float16</span><span class="p">,</span>
<span class="lineno">42</span>                                     <span class="n">device</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span>
<span class="lineno">43</span>                                     <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">44</span>                                     <span class="p">)</span>
<span class="lineno">45</span>    <span class="n">layers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">layer_generator</span><span class="o">.</span><span class="n">load</span><span class="p">())</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-5'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-5'>#</a>
            </div>
            <p>This reduces CUDA memory fragmentation </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">48</span>    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Convert to int8&#39;</span><span class="p">,</span> <span class="n">layers</span><span class="p">,</span> <span class="n">is_children_silent</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="lineno">49</span>        <span class="n">layer_generator</span><span class="o">.</span><span class="n">post_load_prepare</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span>
<span class="lineno">50</span>                                          <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">51</span>                                          <span class="n">is_llm_int8</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="lineno">52</span>                                          <span class="n">llm_int8_threshold</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span>
<span class="lineno">53</span>                                          <span class="p">)</span>
<span class="lineno">54</span>        <span class="n">layer</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-6'>#</a>
            </div>
            <p>Create <code  class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span></code>
 model </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">57</span>    <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-7'>#</a>
            </div>
            <p>Clear cache and print memory summary for debugging </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">60</span>    <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
<span class="lineno">61</span>    <span class="nb">print</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_summary</span><span class="p">())</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-8'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-8'>#</a>
            </div>
            <p>Get token ids </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">64</span>    <span class="n">ids</span> <span class="o">=</span> <span class="n">get_tokens</span><span class="p">(</span><span class="n">PROMPT</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-9'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-9'>#</a>
            </div>
            <p>Run the model. We use the <a href="generate.html"><code  class="highlight"><span></span><span class="n">infer</span></code>
</a> function defined in <a href="generate.html"><code  class="highlight"><span></span><span class="n">generate</span><span class="o">.</span><span class="n">py</span></code>
</a> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">68</span>    <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">,</span> <span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="lineno">69</span>    <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Infer&#39;</span><span class="p">):</span>
<span class="lineno">70</span>        <span class="n">next_token</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">ids</span><span class="p">,</span> <span class="n">device</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-10'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-10'>#</a>
            </div>
            <p>Append the predicted token </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">73</span>    <span class="n">ids</span> <span class="o">+=</span> <span class="p">[</span><span class="n">next_token</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-11'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-11'>#</a>
            </div>
            <p>Predict 100 tokens </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">76</span>    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-12'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-12'>#</a>
            </div>
            <p>Set the state to use cached activations </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">78</span>        <span class="n">cache</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s1">&#39;state_ids&#39;</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-13'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-13'>#</a>
            </div>
            <p>Get next token. Note that we only feed the last token to the model because we cache the key/value pairs of previous tokens. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">81</span>        <span class="k">with</span> <span class="n">monit</span><span class="o">.</span><span class="n">section</span><span class="p">(</span><span class="s1">&#39;Infer&#39;</span><span class="p">):</span>
<span class="lineno">82</span>            <span class="n">next_token</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">[</span><span class="n">next_token</span><span class="p">],</span> <span class="n">device</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-14'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-14'>#</a>
            </div>
            <p>Append the predicted token </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">84</span>        <span class="n">ids</span> <span class="o">+=</span> <span class="p">[</span><span class="n">next_token</span><span class="p">]</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-15'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-15'>#</a>
            </div>
            <p>Print </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">86</span>        <span class="n">print_tokens</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="p">[</span><span class="n">ids</span><span class="p">])</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-16'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-16'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">90</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">91</span>    <span class="n">generate</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='footer'>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
    // Attach the click-to-open modal behaviour (see handleImage) to every
    // <img> that is a direct child of a <p>.
    function handleImages() {
        var targets = document.querySelectorAll('p>img')

        // A NodeList is array-like, so borrow Array's forEach to walk it in order.
        Array.prototype.forEach.call(targets, function (target) {
            handleImage(target)
        })
    }

    // Centre `img` inside its parent and wire up a click-to-open modal for it.
    //
    // The modal is built once per image as
    //   div#modal > (div > img) + span.close
    // and is attached to <body> only while it is open.
    //
    // img: the image element to enhance. Its parent element's text alignment
    //      is set to 'center' and its `onclick` handler is replaced.
    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        // Open: attach the modal to the page and point its image at the clicked one.
        img.onclick = function () {
            document.body.appendChild(modal)
            modalImage.src = img.src
            // Carry the text alternative over so the modal copy is not an unlabelled image.
            modalImage.alt = img.alt || ''
        }

        // Close: detach the modal. Guard on parentNode because removeChild
        // throws when the node is not currently attached (e.g. a repeated close).
        span.onclick = function () {
            if (modal.parentNode) {
                modal.parentNode.removeChild(modal)
            }
        }
    }

    // Run once immediately; this script sits at the end of <body>, after the page content.
    handleImages()
</script>
</body>
</html>